"""
Various utils for data preprocessing
"""
# pylint: disable=anomalous-backslash-in-string
# pylint: disable=invalid-name
# pylint: disable=missing-function-docstring
# pylint: disable=no-else-return
import os
os.environ["CURL_CA_BUNDLE"] = ""

import pickle
import numpy as np
import torch
import dgl
import networkx as nx
import dgl.data
import scipy.sparse as sp
from sklearn.preprocessing import LabelBinarizer

def hash_list(foo):
    """Return an order-insensitive hash of the elements of the list ``foo``.

    ``foo`` itself is left untouched: a sorted copy is converted to a tuple and
    hashed, so two lists holding the same elements in any order hash the same.
    """
    assert isinstance(foo, list)
    return hash(tuple(sorted(foo)))

class Dataloader:
    """Load a DGL benchmark graph and export it in other projects' pickle formats.

    ``to_GCN_LPA`` writes the tuple layout used by hwwang55/GCN-LPA and
    ``to_NOSMOG`` writes the one used by meettyj/NOSMOG (links in the docstrings).

    The constructor seeds NumPy's *global* RNG, and ``split_dataset`` draws from
    it. Every export therefore consumes the same random stream, so a second
    export from the same instance will generally get a different split. Compare
    the split hashes printed by ``split_dataset`` to check consistency.
    """

    # dataset name -> (class name inside ``dgl.data``, whether get_index runs on it).
    # Names are resolved with getattr at load time so only the requested dataset
    # class is touched. The Amazon co-buy datasets skip get_index; presumably they
    # carry no predefined train/val/test masks (TODO confirm).
    _DATASETS = {
        "cora": ("CoraGraphDataset", True),
        "citeseer": ("CiteseerGraphDataset", True),
        "pubmed": ("PubmedGraphDataset", True),
        "reddit": ("RedditDataset", True),
        "a-computer": ("AmazonCoBuyComputerDataset", False),
        "a-photo": ("AmazonCoBuyPhotoDataset", False),
    }

    def __init__(self, dataset_name, random_seed):
        """Load ``dataset_name`` and seed NumPy's global RNG with ``random_seed``."""
        self.dataset = self.load_data(dataset_name)
        # Global seed: split_dataset samples through np.random.choice.
        np.random.seed(random_seed)

    ## Private helpers
    @staticmethod
    def _mask_to_idx(mask):
        """Return the positions of truthy entries of a 1-D mask as an int64 tensor.

        Accepts a torch tensor (any device) or anything ``np.asarray`` accepts.
        An all-false mask yields an empty int64 tensor.
        """
        if isinstance(mask, torch.Tensor):
            return torch.nonzero(mask).flatten()
        return torch.as_tensor(np.flatnonzero(np.asarray(mask)), dtype=torch.long)

    @staticmethod
    def _sparse_to_tuple(sparse_matrix):
        """Return ``(indices, values, shape)`` for a SciPy sparse matrix.

        ``indices`` is an (nnz, 2) array of (row, col) pairs in COO order.
        https://github.com/hwwang55/GCN-LPA/blob/21df23afee0912380ac682b0f80ac244140c33e7/src/data_loader.py#L25
        """
        if not sp.isspmatrix_coo(sparse_matrix):
            sparse_matrix = sparse_matrix.tocoo()
        indices = np.vstack((sparse_matrix.row, sparse_matrix.col)).transpose()
        return indices, sparse_matrix.data, sparse_matrix.shape

    @staticmethod
    def _row_normalize(mx):
        """Scale every row of sparse ``mx`` to sum to 1; all-zero rows stay zero.

        https://github.com/hwwang55/GCN-LPA/blob/21df23afee0912380ac682b0f80ac244140c33e7/src/data_loader.py#L16
        https://github.com/meettyj/NOSMOG/blob/7f8ef63e955341f19e81107d1fc6e24f7e5763e3/data_preprocess.py#L29
        """
        rowsum = np.asarray(mx.sum(1))
        if not np.issubdtype(rowsum.dtype, np.floating):
            # np.power(int_array, -1) raises for integer dtypes.
            rowsum = rowsum.astype(np.float64)
        with np.errstate(divide='ignore'):  # empty rows -> inf, zeroed below
            r_inv = np.power(rowsum, -1).flatten()
        r_inv[np.isinf(r_inv)] = 0.0
        return sp.diags(r_inv).dot(mx)

    @staticmethod
    def _print_split_summary(train_mask, val_mask, test_mask):
        """Print type/shape/sum of each split mask and close the summary section."""
        for name, mask in (('train_mask', train_mask), ('val_mask', val_mask), ('test_mask', test_mask)):
            print('%s: type: %s, %s.shape: %s, %s_sum: %s' % (name, type(mask), name, str(mask.shape), name, mask.sum()))
        print('train_mask head 5 and tail 5:', train_mask[:5], train_mask[-5:])
        print('-----------\n')

    ## Basic utils
    def load_data(self, name):
        """Instantiate dataset ``name`` from ``dgl.data``, attaching index tensors where applicable.

        Raises:
            ValueError: if ``name`` is not a key of ``_DATASETS``.
        """
        entry = self._DATASETS.get(name) if isinstance(name, str) else None
        if entry is None:
            raise ValueError(f"Invalid data name: {name}")
        class_name, build_index = entry
        dataset = getattr(dgl.data, class_name)()
        if build_index:
            self.get_index(dataset)
        return dataset

    def get_index(self, dataset):
        """Attach ``train_idx``/``val_idx``/``test_idx`` to ``dataset``.

        Each attribute is an int64 tensor holding the positions where
        ``dataset.train_mask``/``val_mask``/``test_mask`` is truthy.
        """
        dataset.train_idx = self._mask_to_idx(dataset.train_mask)
        dataset.val_idx = self._mask_to_idx(dataset.val_mask)
        dataset.test_idx = self._mask_to_idx(dataset.test_mask)

    def get_mask(self, idx, length):
        """Return a float64 array of ``length`` zeros with ones at positions ``idx``.

        https://github.com/hwwang55/GCN-LPA/blob/21df23afee0912380ac682b0f80ac244140c33e7/src/data_loader.py#L34
        """
        mask = np.zeros(length, dtype=np.float64)
        mask[idx] = 1
        return mask

    def split_dataset(self, n_samples, val_ratio=0.2, test_ratio=0.2):
        """Randomly split ``n_samples`` items into train/val/test masks.

        https://github.com/hwwang55/GCN-LPA/blob/21df23afee0912380ac682b0f80ac244140c33e7/src/data_loader.py#L40

        ``int(n_samples * val_ratio)`` items go to validation and
        ``int(n_samples * test_ratio)`` to test. Both are drawn without
        replacement from NumPy's global RNG, and the rest form the training set.
        An order-insensitive hash of each index set is printed so splits can be
        compared across runs.

        Returns:
            (train_mask, eval_mask, test_mask): float64 0/1 arrays of length ``n_samples``.

        Raises:
            ValueError: if a ratio is negative or the ratios sum to more than 1.
        """
        if val_ratio < 0 or test_ratio < 0 or val_ratio + test_ratio > 1:
            raise ValueError(
                f"Invalid split ratios: val_ratio={val_ratio}, test_ratio={test_ratio}"
            )
        n_val = int(n_samples * val_ratio)
        n_test = int(n_samples * test_ratio)
        # The exact sampling calls (including sampling from list(left), i.e. set
        # iteration order) are kept so seeded splits match earlier runs.
        val_indices = np.random.choice(list(range(n_samples)), size=n_val, replace=False)
        left = set(range(n_samples)) - set(val_indices)
        test_indices = np.random.choice(list(left), size=n_test, replace=False)
        train_indices = list(left - set(test_indices))
        train_mask = self.get_mask(train_indices, n_samples)
        eval_mask = self.get_mask(val_indices, n_samples)
        test_mask = self.get_mask(test_indices, n_samples)
        # Yaochen: check the hash to make sure data consistency
        train_hash = hash_list(train_indices)
        val_hash = hash_list(list(val_indices))
        test_hash = hash_list(list(test_indices))
        print(f"Split check: train hash {train_hash}, val hash {val_hash}, test hash {test_hash}")
        return train_mask, eval_mask, test_mask

    def to_GCN_LPA(self, save_path):
        """Export the loaded graph to ``save_path`` in GCN-LPA's pickle layout.

        Pickled tuple: ``(features, one_hot_labels, adj, train_mask, val_mask, test_mask)``.
        ``features`` is the row-normalized feature matrix as an (indices, values,
        shape) tuple. ``adj`` is the adjacency matrix with a self-loop on every
        node, in the same tuple form. The masks are float64 0/1 arrays from a
        fresh ``split_dataset`` call.

        Returns:
            0
        """
        graph = self.dataset[0]
        # features
        features = sp.coo_matrix(graph.ndata['feat'], dtype=np.float64)
        features = self._sparse_to_tuple(self._row_normalize(features))
        # labels
        category_labels = graph.ndata['label'].numpy()
        binarizer = LabelBinarizer()
        one_hot_labels = binarizer.fit_transform(category_labels)
        if one_hot_labels.shape[1] == 1 and len(binarizer.classes_) == 2:
            # LabelBinarizer collapses two classes into one column; expand to one-hot.
            one_hot_labels = np.hstack([1 - one_hot_labels, one_hot_labels])
        # adj
        nx_graph = dgl.to_networkx(graph)
        nx_graph.add_edges_from(
            [(i, i) for i in range(len(nx_graph.nodes)) if not nx_graph.has_edge(i, i)]
        )  # add self-loops
        adj = self._sparse_to_tuple(nx.adjacency_matrix(nx_graph))
        # train, val, test
        train_mask, val_mask, test_mask = self.split_dataset(one_hot_labels.shape[0])
        # print
        print('\n-----------')
        print('features: type: %s, index.type: %s, data_type: %s, features.shape: %s'%(type(features), features[0].dtype, features[1].dtype, features[2]))
        print('labels: type: %s, labels.shape: %s'%(type(one_hot_labels), str(one_hot_labels.shape)))
        print('adj: type: %s, index.type: %s, data_type: %s, adj.shape: %s'%(type(adj), adj[0].dtype, adj[1].dtype, adj[2]))
        self._print_split_summary(train_mask, val_mask, test_mask)
        # save
        self.pickle_dump(
            save_path,
            (features, one_hot_labels, adj, train_mask, val_mask, test_mask)
        )
        return 0

    def to_NOSMOG(self, save_path):
        """Export the loaded graph to ``save_path`` in NOSMOG's pickle layout.

        Pickled tuple: ``(g, labels, idx_train, idx_val, idx_test)``.
        ``g`` is a new DGL graph whose edges are the nonzeros of the
        row-normalized ``A + I``. Only that sparsity pattern is kept; the
        normalized values are not stored on ``g``. ``g.ndata['feat']`` is copied
        from the source graph, ``labels`` is its ``ndata['label']``, and the
        ``idx_*`` are int64 index tensors from a fresh ``split_dataset`` call.
        Adapted from
        https://github.com/meettyj/NOSMOG/blob/7f8ef63e955341f19e81107d1fc6e24f7e5763e3/data_preprocess.py#L39

        Returns:
            0
        """
        source = self.dataset[0]
        # g: structure of row-normalized (A + I)
        adj = source.adjacency_matrix_scipy(fmt='csr')
        adj_sp = self._row_normalize(sp.eye(adj.shape[0]) + adj).tocoo()
        g = dgl.graph((adj_sp.row, adj_sp.col), num_nodes=adj_sp.shape[0])
        print_tuple = self._sparse_to_tuple(g.adjacency_matrix_scipy(fmt='coo'))
        g.ndata['feat'] = source.ndata['feat']
        # labels
        labels = source.ndata['label']
        # train, val, test
        train_mask, val_mask, test_mask = self.split_dataset(labels.shape[0])
        idx_train = self._mask_to_idx(train_mask)
        idx_val = self._mask_to_idx(val_mask)
        idx_test = self._mask_to_idx(test_mask)
        # print
        print('\n-----------')
        print('labels: type: %s, labels.shape: %s'%(type(labels), str(labels.shape)))
        print('adj: type: %s, index.type: %s, data_type: %s, adj.shape: %s'%(type(print_tuple), print_tuple[0].dtype, print_tuple[1].dtype, print_tuple[2]))
        self._print_split_summary(train_mask, val_mask, test_mask)

        # save
        self.pickle_dump(
            save_path,
            (g, labels, idx_train, idx_val, idx_test)
        )
        return 0

    def pickle_dump(self, path_out, X, var_name=""):
        """Pickle ``X`` to ``path_out`` (protocol 4), creating missing parent directories."""
        parent_dir = os.path.dirname(path_out)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        with open(path_out, 'wb') as fout:
            pickle.dump(X, fout, protocol=4)
        print(f"Cached {var_name} in {path_out}.")

    def pickle_load(self, path_in, var_name=""):
        """Unpickle and return the object stored at ``path_in``.

        Security: unpickling can execute arbitrary code, so only load files this
        project wrote itself and never untrusted input.
        """
        with open(path_in, 'rb') as fin:
            res = pickle.load(fin)
        print(f"Loaded {var_name} from {path_in}.")
        return res


if __name__ == '__main__':
    # Build a PubMed loader with a fixed seed and export it in NOSMOG's format.
    # For GCN-LPA output, call loader.to_GCN_LPA(<output .pickle path>) instead.
    loader = Dataloader('pubmed', 234)
    loader.to_NOSMOG('../result/tmp/test.pkl')
    